# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Optional
from zipfile import ZipFile

import streamlit as st
from dotenv import find_dotenv, load_dotenv

from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI

"""
This module contains functions that are used across the Streamlit app.
"""


def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
    """
    Initialize the Streamlit app with the necessary configurations.

    Sets the page layout, injects the stylesheet from ``css_file``, seeds the
    session state defaults, loads environment variables from ``env_file`` and,
    when credentials are available, stores the API request headers in session
    state under ``"headers"`` and ``"headers_upload"``.

    Args:
        env_file: File name passed to ``find_dotenv`` to locate the dotenv file.
        css_file: Path of the CSS file injected into the page.

    Returns:
        True if both the APIM subscription key and the deployment URL are
        non-empty (and the headers were stored), otherwise False.
    """
    # set page configuration
    st.set_page_config(initial_sidebar_state="expanded", layout="wide")

    # set custom CSS; decode explicitly as UTF-8 so loading does not depend on
    # the platform's default encoding
    with open(css_file, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # initialize session state variables
    set_session_state_variables()

    # find_dotenv returns "" when no file is found; "or None" presumably lets
    # load_dotenv fall back to its own default lookup in that case
    _ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)

    key_var = EnvVars.APIM_SUBSCRIPTION_KEY.value
    url_var = EnvVars.DEPLOYMENT_URL.value
    # prefer a non-empty environment value; otherwise keep what is already in
    # session state (an env var that is set but empty must not wipe a value
    # the user entered)
    st.session_state[key_var] = os.getenv(key_var) or st.session_state[key_var]
    st.session_state[url_var] = os.getenv(url_var) or st.session_state[url_var]

    subscription_key = st.session_state[key_var]
    if not (subscription_key and st.session_state[url_var]):
        return False

    st.session_state["headers"] = {
        "Ocp-Apim-Subscription-Key": subscription_key,
        "Content-Type": "application/json",
    }
    # upload requests carry no explicit Content-Type (presumably so the HTTP
    # client can set the multipart boundary itself)
    st.session_state["headers_upload"] = {
        "Ocp-Apim-Subscription-Key": subscription_key
    }
    return True


def set_session_state_variables() -> None:
    """
    Seed Streamlit session state with defaults for any keys not yet present.

    The value of every member of ``PromptKeys``, ``StorageIndexVars`` and
    ``EnvVars`` defaults to an empty string; the ``saved_prompts``,
    ``initialized`` and ``new_upload`` flags default to False. Keys that already
    exist in session state are left untouched.
    """
    text_keys = [
        member.value
        for enum_cls in (PromptKeys, StorageIndexVars, EnvVars)
        for member in enum_cls
    ]
    flag_keys = ["saved_prompts", "initialized", "new_upload"]

    defaults = [(name, "") for name in text_keys]
    defaults += [(name, False) for name in flag_keys]

    for name, default in defaults:
        if name in st.session_state:
            continue
        st.session_state[name] = default


def update_session_state_prompt_vars(
    entity_extract: Optional[str] = None,
    summarize: Optional[str] = None,
    community: Optional[str] = None,
    initial_setting: bool = False,
    prompt_dir: str = "./prompts",
) -> None:
    """
    Write LLM prompt texts into their session state slots.

    Args:
        entity_extract: Entity-extraction prompt text.
        summarize: Summarization prompt text.
        community: Community-report prompt text.
        initial_setting: When True, the three prompts are loaded from
            ``prompt_dir`` via ``get_prompts`` and any passed texts are ignored.
        prompt_dir: Directory to load prompts from when ``initial_setting`` is set.
    """
    if initial_setting:
        entity_extract, summarize, community = get_prompts(prompt_dir)

    pending = (
        (PromptKeys.ENTITY.value, entity_extract),
        (PromptKeys.SUMMARY.value, summarize),
        (PromptKeys.COMMUNITY.value, community),
    )
    for state_key, prompt_text in pending:
        # falsy values (None or "") keep the existing session entry
        if not prompt_text:
            continue
        st.session_state[state_key] = prompt_text


def generate_and_extract_prompts(
    client: GraphragAPI,
    storage_name: str,
    zip_file_name: str = "prompts.zip",
    limit: int = 5,
) -> None | Exception:
    """
    Generate LLM prompts through the API and load them into session state.

    Asks ``client`` to generate prompts for ``storage_name`` into
    ``zip_file_name``, unpacks that archive into the working directory, then
    refreshes the prompt session state variables from the default prompt dir.

    Returns:
        None on success; otherwise the exception raised by any step, returned
        to the caller instead of being raised.
    """
    try:
        client.generate_prompts(
            storage_name=storage_name,
            zip_file_name=zip_file_name,
            limit=limit,
        )
        _extract_prompts_from_zip(zip_file_name)
        update_session_state_prompt_vars(initial_setting=True)
    except Exception as err:
        return err
    return None


def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
    with ZipFile(zip_file_name, "r") as zip_ref:
        zip_ref.extractall()


def open_file(file_path: str | Path):
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    return text


def zip_directory(directory_path: str, zip_path: str):
    """
    Zips all contents of a directory into a single zip file.

    Parameters:
    - directory_path: str, the path of the directory to zip
    - zip_path: str, the path where the zip file will be created
    """
    root_dir_name = os.path.basename(directory_path.rstrip("/"))
    with ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(directory_path):
            for file in files:
                file_path = os.path.join(root, file)
                relpath = os.path.relpath(file_path, start=directory_path)
                arcname = os.path.join(root_dir_name, relpath)
                zipf.write(file_path, arcname)


def get_prompts(prompt_dir: str = "./prompts"):
    """
    Extract text from generated prompts.

    Assumes file names comply with pregenerated file name standards: the
    entity-extraction, summarization and community-report prompts are ``.txt``
    files whose names start with ``entity``, ``summ`` and ``community``
    respectively. When several files share a prefix, the alphabetically first
    one is used.

    Args:
        prompt_dir: Directory containing the prompt ``.txt`` files.

    Returns:
        Tuple ``(entity_extraction_prompt, summarization_prompt,
        community_report_prompt)`` of UTF-8 decoded file contents.

    Raises:
        IndexError: if no ``.txt`` file in ``prompt_dir`` starts with one of
            the expected prefixes (same exception type as before, now with a
            descriptive message).
    """
    # sort so the pick is deterministic: iterdir() order is filesystem-dependent
    prompt_paths = sorted(
        (
            path
            for path in Path(prompt_dir).iterdir()
            if path.is_file() and path.name.endswith(".txt")
        ),
        key=lambda path: path.name,
    )

    def _read_first(prefix: str) -> str:
        # read only the selected file rather than every match
        for path in prompt_paths:
            if path.name.startswith(prefix):
                return path.read_text(encoding="utf-8")
        raise IndexError(
            f"No prompt file starting with {prefix!r} found in {str(prompt_dir)!r}"
        )

    entity_ext_prompt = _read_first("entity")
    summ_prompt = _read_first("summ")
    comm_report_prompt = _read_first("community")
    return entity_ext_prompt, summ_prompt, comm_report_prompt
